Skip to content

[HLSL][DXIL] Implement refract intrinsic #147342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from

Conversation

raoanag
Copy link

@raoanag raoanag commented Jul 7, 2025

  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153

Copy link

github-actions bot commented Jul 7, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@raoanag raoanag marked this pull request as ready for review July 7, 2025 16:27
Copy link

github-actions bot commented Jul 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-backend-spir-v

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 57.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147342.diff

19 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRVVK.td (+1)
  • (modified) clang/include/clang/Sema/Sema.h (+24)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+36)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+105)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+64-9)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+36-56)
  • (modified) clang/test/CodeGenHLSL/builtins/reflect.hlsl (+1-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+271)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+23)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+36)
  • (added) llvm/test/CodeGen/SPIRV/opencl/refract-error.ll (+12)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRVVK.td b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
index 61cc0343c415e..5dc3c7588cd2a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVVK.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
 
 def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
+def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 3fe26f950ad51..105ab804fffd0 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
+  /// CheckVectorArgs - Check that the arguments of a vector function call
+  bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
+
+  bool CheckVectorArgs(CallExpr *TheCall);
+
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::ArrayRef<
+          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+          Checks);
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
+
+  static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                            int ArgOrdinal,
+                                            clang::QualType PassedType);
+  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType);
+
+  static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                int ArgOrdinal,
+                                                clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 0687485cd3f80..1c63e04f757c7 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->isFloatingType() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 80c4900121dfb..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
+template <typename T> struct is_vector {
+  static const bool value = false;
+};
+
+template <typename T, int N> struct is_vector<vector<T, N>> {
+  static const bool value = true;
+};
+
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
     vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 4eb7b8f45c85a..f6acb1cea2594 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,42 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
+  T Mul = N * I;
+  T K = 1 - Eta * Eta * (1 - (Mul * Mul));
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+}
+
+template <typename T, typename U>
+constexpr T refract_vec_impl(T I, T N, U Eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  if (is_vector<T>::value) {
+    return __builtin_spirv_refract(I, N, Eta);
+  }
+#else
+  T Mul = dot(N, I);
+  T K = 1 - Eta * Eta * (1 - Mul * Mul);
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+#endif
+}
+
+/*
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
+#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
+  return __builtin_spirv_refract(I, N, Eta);
+#else
+  T Mul = dot(N, I);
+  vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
+  vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
+#endif
+}
+
+*/
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ea880105fac3b..8c262ffce25f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index dd5b710d7e1d4..98bca59f14ecd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
     }
   }
 }
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+    ExprResult Arg = TheCall->getArg(i);
+    QualType ArgTy = Arg.get()->getType();
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy == nullptr) {
+      SemaRef.Diag(Arg.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTy
+          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall) {
+  return CheckVectorArgs(TheCall, TheCall->getNumArgs());
+}
+
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
+  }
+}
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+}
+
+bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                  int ArgOrdinal,
+                                                  clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfScalarRepresentation(
+    Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bad357b50929b..991d330edfb6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckAllArgTypesAreCorrect(
+bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
-                            clang::QualType PassedType)>
-        Check) {
-  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
-    Expr *Arg = TheCall->getArg(I);
-    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-      return true;
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
   }
-  return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
 }
 
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
+static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType = 
+      PassedType->isVectorType()
+        ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index c27d3fed2b990..1b4093065a63a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -157,81 +157,61 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check both arguments
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    QualType RetTy = VTyA->getElementType();
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
   case SPIRV::BI__builtin_spirv_length: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTy = ArgTyA->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+
+    // Use the helper function to check the argument
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
-    QualType RetTy = VTy->getElementType();
+
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
-  case SPIRV::BI__builtin_spirv_reflect: {
-    if (SemaRef.checkArgCount(TheCall, 2))
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
+        ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfScalarRepresentation};
+    if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                           llvm::ArrayRef(ChecksArr)))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->isFloatingType...
[truncated]

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR isn't ready yet.

@raoanag raoanag force-pushed the user/raoanag/refract branch from 515ecda to 729fbf3 Compare July 8, 2025 23:03
return true;

llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
ChecksArr[] = {CheckFloatOrHalfRepresentation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused on if this is meant to handle scalar values as well as vectors? Looking at the code gen, we are asserting that the first two arguments are vectors, but here we allow them to be scalars. @farzonl Does this handle only the case where the first two arguments are vectors?

If that is the case 'CheckFloatOrHalfRepresentation' should be updated to only check for vectors of half or float and should probably be renamed to 'CheckFloatOrHalfVecRepresentation'.

Copy link
Author

@raoanag raoanag Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to support vector of size 1, which is implicitly converted to scalar.
Also HLSL_FIXED_VECTOR only supports Vector of N > 1.

Hence, even though first 2 args are described as vector for N = 1 they are seen as scalar

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your explanation is slightly incorrect, but it does seem the __builtin_spirv_refract is reachable with a scalar value. In this case the codegen assertions are wrong and I will leave a comment there about updating them.

Copy link
Member

@farzonl farzonl Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll review this tomorrow, but supporting scalars here seems wrong. I’m almost 100% sure that spirv via dxc only supports vectors and that our semantics should match that.

Copy link
Contributor

@spall spall Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Refract spirv op says both scalar and vector are supported. https://registry.khronos.org/SPIR-V/specs/unified1/GLSL.std.450.pdf (search for Refract).
But it is up to us what we want to allow.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

referring the khornos doc, SPIR asserts for vectorness for all the intrinsics - distance, length, reflect, smoothstep, faceforward would need to be updated since they all mention operands must all be a scalar or vector whose component type is floating-point.. SPIR.cpp

Looking into how E->getArg(0)->getType()->isVectorType()

isVectorType() checks for vector-ness, not size.
• Vectors of size 1 are technically allowed

Just sharing observations here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had some issues with spirv-val trying to do scalars with those other spirv opcodes. Also i’m pretty sure dxc doesn’t call the glsl instruction in many of the scalar cases. We have been using dxcs behavior as our default spec since we don’t have one.

@@ -0,0 +1,36 @@
; RUN: llc -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
Copy link
Member

@farzonl farzonl Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no spirv64 target if you want to use the glsl extensions. Change the target for the second run line

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement the refract HLSL Function
5 participants